#!/usr/bin/env python
# -*- coding:utf-8 -*-

from libs import printer
from libs import broute
from libs.mitmSql import mitm_get_req_info_by_url
from libs import tools
from conf.setting import DICT_PATH
import os
import traceback

def get_files(fDir,suf):
    """Collect file names found anywhere under *fDir* (walked recursively).

    When *suf* is non-empty, only files whose extension (as returned by
    ``os.path.splitext``, i.e. including the leading dot) equals *suf* are
    kept, and their names are returned WITHOUT the extension.
    When *suf* is empty, every file is kept and full file names are returned.

    Only bare names are returned; the directory a file was found in is not
    part of the result.
    """
    collected = []
    for _root, _dirs, filenames in os.walk(fDir):
        if not suf:
            # No filter requested: keep the complete file names as-is.
            collected.extend(filenames)
            continue
        for filename in filenames:
            stem, ext = os.path.splitext(filename)
            if ext == suf:
                collected.append(stem)
    return collected
def get_req(url):
    """Fetch the stored request information for *url*.

    Delegates to ``mitm_get_req_info_by_url("mitmhttp", url)`` and returns its
    result when it is truthy; otherwise returns an empty dict so callers can
    simply test the result for truthiness.
    """
    found = mitm_get_req_info_by_url("mitmhttp",url)
    return found if found else {}

def get_dict(dname):
    """Resolve a dictionary (wordlist) file name to a path under DICT_PATH.

    Returns ``DICT_PATH + os.sep + dname`` when *dname* is among the file
    names reported by ``get_files(DICT_PATH, "")``, otherwise an empty string.

    NOTE(review): get_files walks DICT_PATH recursively but returns bare
    names, so a file located in a sub-directory would still be mapped to the
    top-level path here -- confirm dictionaries are kept flat.
    """
    available = get_files(DICT_PATH,"")
    if available and dname in available:
        return DICT_PATH + os.sep + dname
    return ''

def show_req(result:dict):
    """Print the URL, headers and params of a request record via ``printer``.

    *result* must provide the keys ``'url'``, ``'headers'`` and ``'params'``.
    ``'url'`` and ``'params'`` are concatenated directly onto a label, so they
    have to be strings; ``'headers'`` is converted with ``str()`` first.
    """
    printer.info("请求信息如下: ")
    url_line = "URL : " + result['url']
    printer.plus(url_line)
    headers_line = "headers[头信息] : " + str(result['headers'])
    printer.plus(headers_line)
    params_line = "params[参数] : " + result['params']
    printer.plus(params_line)
def show_dict():
    """Print every file name found under DICT_PATH, or a warning if none exist."""
    printer.info("收集的字典如下: ")
    names = get_files(DICT_PATH,"")
    if not names:
        printer.warn("无字典")
        return
    for entry in names:
        printer.plus(entry)

def show_modify_menu():
    """Print the menu of request parts that can be chosen for modification."""
    printer.info("可修改的选项: ")
    for entry in ("1. URL", "2. headers[头信息]", "3. params[参数]"):
        printer.plus(entry)
def show_test_mode_menu():
    """Print the menu of the four available test modes."""
    printer.info("测试模式选项: ")
    entries = (
        "1. 普通蛮力模式",
        "2. 单点蛮力模式",
        "3. 草叉蛮力模式",
        "4. 集束蛮力模式",
    )
    for entry in entries:
        printer.plus(entry)

def url_replace_tag(url:str,word,tag="$"):
    """Return *url* with every occurrence of *word* replaced by *tag*.

    The replacement is only performed when ``tools.judge_tag(url, word)`` is
    truthy; otherwise an empty string is returned.
    """
    if not tools.judge_tag(url,word):
        return ''
    return url.replace(word,tag)

def judge_tag_headers(headers:dict,tag):
    """Tell whether *tag* occurs in any header name or header value.

    Header values are compared through ``str(value)``; header names are used
    as-is. A falsy *headers* (``None`` or empty) yields ``False``.
    """
    if not headers:
        return False
    return any(
        tag in name or tag in str(value)
        for name, value in headers.items()
    )

def headers_replace_tag(headers:dict,word,tag="$"):
    """Mark *word* inside *headers* with *tag*.

    When *word* occurs in some header name or value (per judge_tag_headers),
    returns ``tools.replace_tag(headers, tag, word)``; otherwise returns an
    empty dict.
    """
    if not judge_tag_headers(headers,word):
        return {}
    return tools.replace_tag(headers,tag,word)
def pitchfork_replace_headers(headers:dict,word,tag="$"):
    """Return ``tools.replace_tag(headers, tag, word)`` unconditionally.

    Unlike headers_replace_tag, no check is made that *word* occurs in
    *headers* before delegating.
    """
    return tools.replace_tag(headers,tag,word)
def params_replace_tag(params:str,word,tag="$"):
    """Return *params* with every occurrence of *word* replaced by *tag*.

    The replacement is only performed when ``tools.judge_tag(params, word)``
    is truthy; otherwise an empty string is returned.
    """
    if not tools.judge_tag(params,word):
        return ''
    return params.replace(word,tag)
def start_loop(url,words,verb,limit,delay,headers,data,timeout,hc,sc,ht,st,tag,orireq,record,mode):
    """Build a ``broute.broute`` object from the given settings and run its ``loop()``.

    All sixteen arguments are forwarded positionally, in the same order, to
    ``broute.broute``.
    """
    settings = (
        url, words, verb, limit, delay, headers, data, timeout,
        hc, sc, ht, st, tag, orireq, record, mode,
    )
    runner = broute.broute(*settings)
    runner.loop()
def judge_str(d:str,s:list):
    """Return True when every fragment in *s* is a substring of *d*.

    An empty *s* yields True.
    """
    return all(d.find(fragment) != -1 for fragment in s)
def judge_dict(d:dict,s:list):
    """Return True when every fragment in *s* occurs in some key or value of *d*.

    Keys and values are compared through ``str()``. An empty *s* yields True.

    The previous implementation added up every key hit and every value hit
    across the whole dict and compared that total with ``len(s)``. That
    rejected a fragment present in more than one header (total too large) and
    accepted input where one fragment matched twice while another was missing
    entirely (total coincidentally equal). Each fragment is now checked
    independently.
    """
    return all(
        any(fragment in str(key) or fragment in str(value) for key, value in d.items())
        for fragment in s
    )
    
def sniper_mode(url):
    """Interactive flow for test mode 2 (menu entry "2. 单点蛮力模式").

    Steps, all driven by ``input()`` prompts:
      1. Look up the stored request for *url*; warn and stop if there is none.
      2. Ask which part of the request to target (1 URL, 2 headers, 3 params)
         and for the ``$``-joined fragments to use as injection points. Every
         fragment must occur in the chosen part, otherwise a selection error
         is reported.
      3. Ask for the optional filters and tuning values.
      4. Ask for a single dictionary name and hand everything to start_loop
         with mode ``2``.

    The request itself is passed on unmodified; the fragments are recorded in
    ``record[<part>]`` keyed by ``$0$``, ``$1$``, ...

    Raises:
        ValueError: if the coroutine count or the delay is not an integer
            (``int()`` is applied to the raw input).
    """
    req = get_req(url)
    if not req:
        printer.warn("请求不存在!")
        return

    # Two distinct lists on purpose: a chained `hc = sc = []` would bind both
    # names to the same list, so a hidden status code would also show up in
    # the shown-status list (and vice versa).
    hc = []
    sc = []
    ht = None
    st = None
    limit = 1
    delay = 1
    timeout = 3
    oReq = {}
    record = {}
    sign = False
    dicts = []

    show_req(req)
    show_modify_menu()
    # Stripped so that stray whitespace does not invalidate the selection,
    # matching the other modes.
    choice = input("选项: ").strip()

    # menu choice -> (record key, check that all fragments occur in that part)
    targets = {
        "1": ("url", lambda parts: judge_str(url, parts)),
        "2": ("headers", lambda parts: judge_dict(req['headers'], parts)),
        "3": ("params", lambda parts: judge_str(req['params'], parts)),
    }
    if choice in targets:
        field, contains_all = targets[choice]
        w = input("注入点($连接符): ")
        # `> 0` (not `>= 0`): a "$" in the very first position is rejected.
        if w.find("$") > 0:
            ws = w.strip().split("$")
            if contains_all(ws):
                oReq = req.copy()
                record[field] = {
                    "$" + str(i) + "$": fragment for i, fragment in enumerate(ws)
                }
                sign = True

    if not sign:
        printer.warn("选择错误!")
        return

    ihc = input("隐藏的状态码[可选]: ")
    isc = input("显示的状态码[可选]: ")
    iht = input("含此值则显示[可选]: ")
    ist = input("含此值则忽略[可选]: ")
    ili = input("请输入协程数[可选]: ")
    ide = input("输入间隔时间[可选]: ")
    if ihc:
        hc.append(ihc)
    if isc:
        sc.append(isc)
    if iht:
        ht = iht
    if ist:
        st = ist
    if ili:
        limit = int(ili)
    if ide:
        delay = int(ide)

    show_dict()
    idict = get_dict(input("选择字典[单选]: ").strip())
    if not idict:
        printer.warn("字典加载错误!")
        return
    dicts.append(idict)
    start_loop(req['url'],dicts,req['method'],limit,delay,req['headers'],req['params'],timeout,hc,sc,ht,st,"$",oReq,record,2)
def _multi_point_mode(url, mode):
    """Shared interactive flow behind pitchfork_mode and cluster_mode.

    The two public modes differ only in the *mode* number handed to
    start_loop, so the whole dialogue lives here.

    Steps, all driven by ``input()`` prompts:
      1. Look up the stored request for *url*; warn and stop if there is none.
      2. Ask which part of the request to target (1 URL, 2 headers, 3 params)
         and for the ``$``-joined fragments to use as injection points. Every
         fragment must occur in the chosen part. Fragment ``i`` is recorded in
         ``record[<part>]`` under the placeholder ``$i$`` and the chosen part
         of the request is rewritten through the matching ``*_replace_*``
         helper.
      3. Ask for the optional filters and tuning values.
      4. Ask for a comma-separated list of ``index$dictionary-name`` items,
         one per placeholder, resolve each name with get_dict and call
         start_loop.

    Invalid dictionary input prints a warning and returns. Earlier code used a
    bare ``raise`` there, which -- with no exception being handled -- only
    produced ``RuntimeError: No active exception to reraise``.

    Raises:
        ValueError: if the coroutine count or the delay is not an integer
            (``int()`` is applied to the raw input).
    """
    req = get_req(url)
    if not req:
        printer.warn("请求不存在!")
        return

    hc = []
    sc = []
    ht = None
    st = None
    limit = 1
    delay = 1
    timeout = 3
    oReq = {}
    record = {}
    sign = False
    dicts = []

    show_req(req)
    show_modify_menu()
    choice = input("选项: ").strip()
    field = {"1": "url", "2": "headers", "3": "params"}.get(choice)

    if field is not None:
        w = input("注入点($连接符): ")
        # `> 0` (not `>= 0`): a "$" in the very first position is rejected.
        if w.find("$") > 0:
            ws = w.strip().split("$")
            if field == "url" and judge_str(url, ws):
                oReq = req.copy()
                tUrl = url
                record['url'] = {}
                for i, fragment in enumerate(ws):
                    tag = '$' + str(i) + '$'
                    tUrl = url_replace_tag(tUrl, fragment, tag=tag)
                    record['url'][tag] = fragment
                req['url'] = tUrl
                sign = True
            elif field == "headers" and judge_dict(req['headers'], ws):
                oReq = req.copy()
                record['headers'] = {}
                nHeaders = req['headers']
                for i, fragment in enumerate(ws):
                    tag = "$" + str(i) + "$"
                    record['headers'][tag] = fragment
                    # NOTE(review): the URL and params branches pass the
                    # fragment as `word` and the placeholder as `tag`; this
                    # call passes them the other way round. Kept exactly as
                    # before -- confirm against tools.replace_tag.
                    nHeaders = pitchfork_replace_headers(nHeaders, tag, tag=fragment)
                req['headers'] = nHeaders
                sign = True
            elif field == "params" and judge_str(req['params'], ws):
                oReq = req.copy()
                record['params'] = {}
                nParams = req['params']
                for i, fragment in enumerate(ws):
                    tag = "$" + str(i) + "$"
                    record['params'][tag] = fragment
                    nParams = params_replace_tag(nParams, fragment, tag=tag)
                req['params'] = nParams
                sign = True

    if not sign:
        printer.warn("选择错误!")
        return

    ihc = input("隐藏的状态码[可选]: ")
    isc = input("显示的状态码[可选]: ")
    iht = input("含此值则显示[可选]: ")
    ist = input("含此值则忽略[可选]: ")
    ili = input("请输入协程数[可选]: ")
    ide = input("输入间隔时间[可选]: ")
    if ihc:
        hc.append(ihc)
    if isc:
        sc.append(isc)
    if iht:
        ht = iht
    if ist:
        st = ist
    if ili:
        limit = int(ili)
    if ide:
        delay = int(ide)

    show_dict()
    idict = input("选择字典[0$字典名,1$...]: ").strip()
    # At least two fragments exist at this point (the injection string had to
    # contain "$"), so a valid answer always contains a comma.
    if ',' not in idict:
        printer.warn("字典格式错误!")
        return
    ids = idict.split(",")

    # Exactly one part was selected above, so `record` holds a single entry.
    if len(ids) != len(record[field]):
        printer.warn("参数与字典个数需对应!")
        return

    for item in ids:
        pieces = item.split("$")
        if len(pieces) != 2:
            # Not of the form "index$name"; previously this surfaced as an
            # unpacking ValueError.
            printer.warn("字典格式错误!")
            return
        index, dict_name = pieces
        dict_path = get_dict(dict_name)
        if not dict_path:
            printer.warn("字典加载错误!")
            return
        dicts.append(index + "$" + dict_path)

    start_loop(req['url'],dicts,req['method'],limit,delay,req['headers'],req['params'],timeout,hc,sc,ht,st,"$",oReq,record,mode)


def pitchfork_mode(url):
    """Interactive flow for test mode 3 (menu entry "3. 草叉蛮力模式").

    Runs the multi-injection-point dialogue and calls start_loop with mode ``3``.
    """
    _multi_point_mode(url, 3)


def cluster_mode(url):
    """Interactive flow for test mode 4 (menu entry "4. 集束蛮力模式").

    Runs the multi-injection-point dialogue and calls start_loop with mode ``4``.
    """
    _multi_point_mode(url, 4)
def normal_mode(url):
    """Interactive flow for test mode 1 (menu entry "1. 普通蛮力模式").

    Steps, all driven by ``input()`` prompts:
      1. Look up the stored request for *url*; warn and stop if there is none.
      2. Ask which part of the request to target (1 URL, 2 headers, 3 params)
         and for a single injection fragment. The fragment is replaced by the
         default ``$`` tag in the chosen part through the matching
         ``*_replace_tag`` helper; an empty result from that helper counts as
         a selection error.
      3. Ask for the optional filters and tuning values.
      4. Ask for a dictionary name and call start_loop with mode ``1`` and an
         empty record.

    Raises:
        ValueError: if the coroutine count or the delay is not an integer
            (``int()`` is applied to the raw input).
    """
    req = get_req(url)
    if not req:
        printer.warn("请求不存在!")
        return

    # Two distinct lists on purpose: a chained `hc = sc = []` would bind both
    # names to the same list, so a hidden status code would also show up in
    # the shown-status list (and vice versa).
    hc = []
    sc = []
    ht = None
    st = None
    limit = 1
    delay = 1
    timeout = 3
    oReq = {}
    sign = False

    show_req(req)
    show_modify_menu()
    choice = input("选项: ").strip()

    # menu choice -> (request key, function producing the tagged replacement)
    replacers = {
        "1": ("url", lambda word: url_replace_tag(url, word)),
        "2": ("headers", lambda word: headers_replace_tag(req['headers'], word)),
        "3": ("params", lambda word: params_replace_tag(req['params'], word)),
    }
    if choice in replacers:
        field, make_tagged = replacers[choice]
        w = input("注入点: ").strip()
        if w:
            # Snapshot taken before the request is rewritten below.
            oReq = req.copy()
            tagged = make_tagged(w)
            if tagged:
                req[field] = tagged
                sign = True

    if not sign:
        printer.warn("选择错误!")
        return

    ihc = input("隐藏的状态码[可选]: ")
    isc = input("显示的状态码[可选]: ")
    iht = input("含此值则显示[可选]: ")
    ist = input("含此值则忽略[可选]: ")
    ili = input("请输入协程数[可选]: ")
    ide = input("输入间隔时间[可选]: ")
    if ihc:
        hc.append(ihc)
    if isc:
        sc.append(isc)
    if iht:
        ht = iht
    if ist:
        st = ist
    if ili:
        limit = int(ili)
    if ide:
        delay = int(ide)

    show_dict()
    idict = get_dict(input("选择字典: ").strip())
    if not idict:
        printer.warn("字典加载错误!")
        return
    # Unlike the other modes, the dictionary path is passed as a plain string
    # rather than wrapped in a list (unchanged from the original call).
    start_loop(req['url'],idict,req['method'],limit,delay,req['headers'],req['params'],timeout,hc,sc,ht,st,"$",oReq,{},1)
def execute(args):
    """Entry point: ask for a URL and a test mode, then run that mode.

    *args* is accepted but not used by this function.

    The mode answer is parsed with ``int()`` once and dispatched through a
    table. An answer that is not a number, or not one of 1-4, now produces a
    selection warning; previously a non-numeric answer dumped a traceback and
    an out-of-range number was silently ignored.

    Any exception escaping the selected mode is caught here and its traceback
    printed, so this function itself does not raise ``Exception`` subclasses.
    """
    try:
        printer.warn("异步协程实现的蛮力工具，后续添加更多处理逻辑!")
        url = input("请输入url: ").strip()
        show_test_mode_menu()
        mode = input("选择: ")
        handlers = {
            1: normal_mode,
            2: sniper_mode,
            3: pitchfork_mode,
            4: cluster_mode,
        }
        try:
            handler = handlers.get(int(mode))
        except ValueError:
            handler = None
        if handler is None:
            printer.warn("选择错误!")
            return
        handler(url)
    except Exception:
        # Top-level boundary of the interactive tool: report and carry on.
        traceback.print_exc()